#' Format MR results for forest plot
#'
#' This function takes the results from [mr()] and is particularly useful
#' if the MR has been applied using multiple exposures and multiple outcomes.
#' It creates a new data frame with the following:
#' \itemize{
#' \item Variables: exposure, outcome, category, outcome sample size, effect, upper ci, lower ci, pval, nsnp
#' \item only one estimate for each exposure-outcome
#' \item exponentiated effects if required
#' }
#'
#' By default it uses the [available_outcomes()] function to retrieve the study level characteristics for the outcome trait,
#' including sample size and outcome category.
#' This assumes the MR analysis was performed using outcome GWAS(s) contained in MR-Base.
#'
#' If \code{ao_slc} is set to \code{FALSE} then the user must supply their own study level characteristics.
#' This is useful when the user has supplied their own outcome GWAS results (i.e. they are not in MR-Base).
#'
#' Every outcome gets a row for every exposure: exposure-outcome pairs without an
#' estimate are added with missing effect, confidence interval, p-value and nsnp.
#' Rows are sorted by outcome; if `priority` is one of the categories, its rows are
#' placed first and rows of the `"Other"` category last.
#'
#' @param mr_res Results from [mr()].
#' @param exponentiate Convert effects to OR? The default is `FALSE`.
#' @param single_snp_method Which of the single SNP methods to use when only 1 SNP was used to estimate the causal effect? The default is `"Wald ratio"`.
#' @param multi_snp_method Which of the multi-SNP methods to use when there was more than 1 SNPs used to estimate the causal effect? The default is `"Inverse variance weighted"`.
#' @param ao_slc Logical; retrieve sample size and subcategory using [available_outcomes()]. If set to `FALSE` `mr_res` must contain the following additional columns: `subcategory` and `sample_size`.
#' @param priority Name of category to prioritise at the top of the forest plot. The default is `"Cardiometabolic"`.
#'
#' @export
#' @return data frame.
format_mr_results <- function(mr_res, exponentiate=FALSE, single_snp_method="Wald ratio", multi_snp_method="Inverse variance weighted", ao_slc=TRUE, priority="Cardiometabolic")
{
	# Get extra info on outcomes
	if(ao_slc)
	{
		ao <- available_outcomes()
		# Pool cardiovascular traits and type 2 diabetes into one category
		ao$subcategory[ao$subcategory == "Cardiovascular"] <- "Cardiometabolic"
		ao$subcategory[ao$trait == "Type 2 diabetes"] <- "Cardiometabolic"
		# Rename so it does not clash with the nsnp column of mr_res when merging
		names(ao)[names(ao) == "nsnp"] <- "nsnp.array"
	}

	# Keep a single estimate per exposure-outcome pair
	dat <- subset(mr_res, (nsnp==1 & method==single_snp_method) | (nsnp > 1 & method == multi_snp_method))
	# Remember the row order; used as a tie-breaker in the final sort.
	# seq_len() (not 1:nrow) so that zero rows does not produce c(1, 0).
	dat$index <- seq_len(nrow(dat))

	if(ao_slc)
	{
		dat <- merge(dat, ao, by.x="id.outcome", by.y="id")
	}
	dat <- dat[order(dat$b), ]

	# Create CIs
	dat$up_ci <- as.numeric(dat$b) + 1.96 * as.numeric(dat$se)
	dat$lo_ci <- as.numeric(dat$b) - 1.96 * as.numeric(dat$se)

	# Exponentiate?
	if(exponentiate)
	{
		dat$b <- exp(as.numeric(dat$b))
		dat$up_ci <- exp(dat$up_ci)
		dat$lo_ci <- exp(dat$lo_ci)
	}

	# Organise cats
	dat$subcategory <- as.factor(dat$subcategory)

	if(!ao_slc)
	{
		# Without available_outcomes() there is no trait column, so derive one from the
		# outcome label. Labels generated by mr() for MR-Base outcomes contain "||"
		# separators (trait || consortium || year); keep only the leading trait part.
		# Labels without "||" (e.g. user-supplied outcomes) are used unchanged.
		dat$trait <- dat$outcome
		pos <- grep("\\|\\|", dat$trait)
		if(length(pos) > 0)
		{
			# Take the first field of each label individually, so labels with a
			# different number of "||" fields cannot misalign the traits.
			fields <- strsplit(dat$trait[pos], split="\\|\\|")
			dat$trait[pos] <- trim(vapply(fields, function(f) f[1], character(1)))
		}
	}

	dat <- data.frame(
		exposure = as.character(dat$exposure),
		outcome = as.character(dat$trait),
		category = as.character(dat$subcategory),
		effect = dat$b,
		up_ci = dat$up_ci,
		lo_ci = dat$lo_ci,
		nsnp = dat$nsnp,
		pval = dat$pval,
		sample_size = dat$sample_size,
		index = dat$index,
		stringsAsFactors = FALSE
	)

	# For each outcome, add placeholder rows for the exposures that have no estimate.
	# rbind.fill() leaves the remaining columns (effect, CIs, pval, nsnp, index) NA.
	exps <- unique(dat$exposure)
	dat <- plyr::ddply(dat, c("outcome"), function(x)
	{
		missed <- exps[! exps %in% x$exposure]
		if(length(missed) >= 1)
		{
			# NOTE(review): assumes category and sample_size are constant within an outcome name
			md <- data.frame(
				exposure = missed,
				outcome = unique(x$outcome),
				category = unique(x$category),
				sample_size = unique(x$sample_size),
				stringsAsFactors = FALSE
			)
			x <- plyr::rbind.fill(x, md)
		}
		return(x)
	})

	# Sort by outcome, then by the original row order (stable sort); placeholder
	# rows have NA index and therefore come last within their outcome.
	dat <- dat[order(dat$outcome, dat$index), ]

	stopifnot(length(priority) == 1)

	if(priority %in% dat$category)
	{
		# Priority category first, "Other" last, everything else in between
		dat <- rbind(
			subset(dat, category==priority),
			subset(dat, !category %in% c(priority,"Other")),
			subset(dat, category=="Other")
		)
	}

	return(dat)
}

#' Simple attempt at correcting string case
#'
#' Lower-cases each element and then capitalises the first letter of every
#' space-separated word, e.g. `"hELLO wORLD"` becomes `"Hello World"`.
#'
#' @param x Character or array of character
#'
#' @keywords internal
#' @return Character vector the same length as `x` (named by the input values
#'   when `x` is character). Missing values are returned as `NA`.
simple_cap <- function(x) {
	# vapply() (not sapply()) so the result is always a character vector,
	# including for zero-length input.
	vapply(x, function(s) {
		# Without this, NA would become the literal string "NANA"
		if(is.na(s)) return(NA_character_)
		words <- strsplit(tolower(s), " ")[[1]]
		paste0(toupper(substring(words, 1, 1)), substring(words, 2), collapse=" ")
	}, character(1))
}

#' Trim function to remove leading and trailing blank spaces
#'
#' Removes any run of whitespace characters (`[[:space:]]`, so tabs and
#' newlines as well as spaces) from the start and the end of each element.
#' Whitespace inside the string is left untouched.
#'
#' @param x Character or array of character
#'
#' @export
#' @return Character or array of character
trim <- function( x ) {
	# Strip the leading run first, then the trailing run
	without_leading <- sub("^[[:space:]]+", "", x)
	sub("[[:space:]]+$", "", without_leading)
}


#' Create fixed width label
#'
#' Builds labels of the form `<n1>    <nom>`, where every number is
#' right-aligned to the widest `n1` and every name is left-padded to the
#' longest `nom`, so all labels share a fixed layout.
#'
#' @param n1 number
#' @param nom name
#'
#' @keywords internal
#' @return text: a factor whose levels follow the order of first appearance
create_label <- function(n1, nom)
{
	# Right-align numbers to the widest one (NA widths ignored)
	num_width <- max(nchar(n1), na.rm = TRUE)
	num_txt <- formatC(n1, width = num_width)

	# Left-align names, padded to the longest one
	name_fmt <- paste0("%-", max(nchar(nom)), "s")
	name_txt <- sprintf(name_fmt, nom)

	labels <- paste(num_txt, name_txt, sep = "    ")
	factor(labels, levels = unique(labels))
}

#' A basic forest plot
#'
#' This function is used to create a basic forest plot.
#' It requires the output from [format_mr_results()].
#'
#' Each outcome label (`lab`) becomes one facet row with alternating background
#' shading, and the exposures are on the y axis within each row. The x-axis range
#' is computed from all rows of `dat` before the `section` and `colour_group`
#' filters are applied, so calls on the same `dat` share a common scale.
#'
#' @param dat Output from [format_mr_results()]. If it has no `lab` column, one is built with [create_label()] from `sample_size` and `outcome`.
#' @param section Which category in dat to plot. If `NULL` then prints everything.
#' @param colour_group Which exposure to plot. If `NULL` then prints everything grouping by colour.
#' @param colour_group_first If `TRUE`, or if `colour_group` is `NULL`, outcome names are written inside the panel (with extra room on the left of the x axis) and the title is shown in black. Otherwise the outcome names and the extra room are omitted and the title is drawn in white. The default is `TRUE`.
#' @param xlab x-axis label. Default=`NULL`.
#' @param bottom Show x-axis text, ticks and label? Default=`TRUE`.
#' @param trans Transformation of x axis.
#' @param xlim x-axis limits. Confidence intervals are truncated to this range; missing intervals stay missing.
#' @param threshold p-value threshold used to mark significant points: by colour when `colour_group` is given, by shape otherwise. If `NULL` (default) no significance layer is applied.
#'
#' @return ggplot object
#' @keywords internal
forest_plot_basic <- function(dat, section=NULL, colour_group=NULL, colour_group_first=TRUE, xlab=NULL, bottom=TRUE, trans="identity", xlim=NULL, threshold=NULL)
{
	# Only a bottom panel shows x-axis text, ticks and label
	if(bottom)
	{
		text_colour <- ggplot2::element_text(colour="black")
		tick_colour <- ggplot2::element_line(colour="black")
		xlabname <- xlab
	} else {
		text_colour <- ggplot2::element_blank()
		tick_colour <- ggplot2::element_blank()
		xlabname <- NULL
	}

	# OR or log(OR)?
	# If CI are symmetric then log(OR) and the null effect is 0; otherwise assume OR (null at 1).
	# isTRUE() because all.equal() returns a character description on mismatch.
	null_line <- if(isTRUE(all.equal(dat$effect - dat$lo_ci, dat$up_ci - dat$effect))) 0 else 1

	# Truncate CIs to xlim. Missing bounds stay NA (no na.rm) so rows without an
	# estimate are not turned into error bars spanning the whole axis.
	if(!is.null(xlim))
	{
		stopifnot(length(xlim) == 2)
		stopifnot(xlim[1] < xlim[2])
		dat$lo_ci <- pmax(dat$lo_ci, xlim[1])
		dat$up_ci <- pmin(dat$up_ci, xlim[2])
	}

	# Axis range from all rows (before filtering). The lower limit is pushed left by
	# half the range to leave room for the outcome labels.
	up <- max(dat$up_ci, na.rm=TRUE)
	lo_orig <- min(dat$lo_ci, na.rm=TRUE)
	lo <- lo_orig - (up - lo_orig) * 0.5

	if(!is.null(section))
	{
		dat <- subset(dat, category==section)
	}

	if(!is.null(colour_group))
	{
		# Single exposure: optionally colour points by significance
		dat <- subset(dat, exposure == colour_group)
		if(!is.null(threshold))
		{
			point_plot <- ggplot2::geom_point(size=2, ggplot2::aes(colour = pval < threshold))
		} else {
			point_plot <- ggplot2::geom_point(size=2)
		}
	} else {
		# All exposures: colour by exposure, optionally shape by significance
		if(!is.null(threshold))
		{
			point_plot <- ggplot2::geom_point(ggplot2::aes(colour=exposure, shape = pval < threshold), size=2)
		} else {
			point_plot <- ggplot2::geom_point(ggplot2::aes(colour=exposure), size=2)
		}
	}

	if(is.null(colour_group) || colour_group_first)
	{
		outcome_labels <- ggplot2::geom_text(ggplot2::aes(label=outcome), x=lo, y=mean(c(1, length(unique(dat$exposure)))), hjust=0, vjust=0.5, size=2.5)
		title_colour <- "black"
	} else {
		# No outcome names: drop the extra left room and hide the (still present) title
		outcome_labels <- NULL
		lo <- lo_orig
		title_colour <- "white"
	}
	main_title <- section

	if(! "lab" %in% names(dat))
	{
		dat$lab <- create_label(dat$sample_size, dat$outcome)
	}

	# Alternate background shading ("a"/"b") between consecutive outcome rows
	l <- data.frame(lab=sort(unique(dat$lab)), col="a", stringsAsFactors=FALSE)
	l$col[seq_len(nrow(l)) %% 2 == 0] <- "b"

	dat <- merge(dat, l, by="lab", all.x=TRUE)
	dat <- dat[rev(seq_len(nrow(dat))), , drop=FALSE]

	p <- ggplot2::ggplot(dat, ggplot2::aes(x=effect, y=exposure)) +
	ggplot2::geom_rect(ggplot2::aes(fill=col), xmin=-Inf, xmax=Inf, ymin=-Inf, ymax=Inf) +
	ggplot2::geom_vline(xintercept=seq(ceiling(lo_orig), ceiling(up), by=0.5), colour="white", size=0.3) +
	ggplot2::geom_vline(xintercept=null_line, colour="#333333", size=0.3) +
	ggplot2::geom_errorbarh(ggplot2::aes(xmin=lo_ci, xmax=up_ci), height=0, size=0.4, colour="#aaaaaa") +
	ggplot2::geom_point(colour="black", size=2.2) +
	point_plot +
	ggplot2::facet_grid(lab ~ .) +
	ggplot2::scale_x_continuous(trans=trans, limits=c(lo, up)) +
	ggplot2::scale_colour_brewer(type="qual") +
	ggplot2::scale_fill_manual(values=c("#eeeeee", "#ffffff"), guide="none") +
	ggplot2::theme(
		axis.line=ggplot2::element_blank(),
		axis.text.y=ggplot2::element_blank(),
		axis.ticks.y=ggplot2::element_blank(),
		axis.text.x=text_colour,
		axis.ticks.x=tick_colour,
		strip.background=ggplot2::element_rect(fill="white", colour="white"),
		strip.text=ggplot2::element_text(family="Courier New", face="bold", size=9),
		legend.position="none",
		legend.direction="vertical",
		panel.grid.minor.x=ggplot2::element_blank(),
		panel.grid.minor.y=ggplot2::element_blank(),
		panel.grid.major.y=ggplot2::element_blank(),
		plot.title = ggplot2::element_text(hjust = 0, size=12, colour=title_colour),
		plot.margin=ggplot2::unit(c(2,3,2,0), units="points"),
		plot.background=ggplot2::element_rect(fill="white"),
		panel.spacing=ggplot2::unit(0,"lines"),
		panel.background=ggplot2::element_rect(colour="red", fill="grey", size=1),
		strip.text.y = ggplot2::element_blank()
	) +
	ggplot2::labs(y=NULL, x=xlabname, colour="", fill=NULL, title=main_title) +
	outcome_labels
	return(p)
}


#' Outcome-name panel for a forest plot
#'
#' Draws one facet row per outcome label (`lab`) containing only the outcome
#' name over alternating background shading, using the same faceting and theme
#' as [forest_plot_basic()]. It is used as the left-hand column when exposures
#' are plotted in separate columns.
#'
#' @param dat Output from [format_mr_results()]. If it has no `lab` column, one is built with [create_label()].
#' @param section Which category in dat to plot. If `NULL` then all rows are used and the (empty) title is drawn in white.
#' @param bottom If `TRUE`, x-axis text and ticks are drawn in white, so they take up space without being visible. If `FALSE` they are removed.
#'
#' @return ggplot object
#' @keywords internal
forest_plot_names <- function(dat, section=NULL, bottom=TRUE)
{
	# The x axis carries no information here; on a bottom panel it is drawn
	# invisibly (white) so that it still takes up space.
	if(bottom)
	{
		text_colour <- ggplot2::element_text(colour="white")
		tick_colour <- ggplot2::element_line(colour="white")
		xlabname <- ""
	} else {
		text_colour <- ggplot2::element_blank()
		tick_colour <- ggplot2::element_blank()
		xlabname <- NULL
	}

	# Fixed [0, 1] x range: the panel only holds the outcome label, drawn from the left edge
	lo <- 0
	up <- 1

	if(!is.null(section))
	{
		dat <- subset(dat, category==section)
		section_colour <- "black"
	} else {
		section_colour <- "white"
	}
	main_title <- section

	outcome_labels <- ggplot2::geom_text(
		ggplot2::aes(label=outcome),
		x=lo,
		y=mean(c(1, length(unique(dat$exposure)))),
		hjust=0, vjust=0.5, size=3.5
	)

	if(! "lab" %in% names(dat))
	{
		dat$lab <- create_label(dat$sample_size, dat$outcome)
	}

	# Alternate background shading ("a"/"b") between consecutive outcome rows
	l <- data.frame(lab=sort(unique(dat$lab)), col="a", stringsAsFactors=FALSE)
	l$col[seq_len(nrow(l)) %% 2 == 0] <- "b"

	dat <- merge(dat, l, by="lab", all.x=TRUE)

	p <- ggplot2::ggplot(dat, ggplot2::aes(x=effect, y=exposure)) +
	ggplot2::geom_rect(ggplot2::aes(fill=col), xmin=-Inf, xmax=Inf, ymin=-Inf, ymax=Inf) +
	ggplot2::facet_grid(lab ~ .) +
	ggplot2::scale_x_continuous(limits=c(lo, up)) +
	ggplot2::scale_colour_brewer(type="qual") +
	ggplot2::scale_fill_manual(values=c("#eeeeee", "#ffffff"), guide="none") +
	ggplot2::theme(
		axis.line=ggplot2::element_blank(),
		axis.text.y=ggplot2::element_blank(),
		axis.ticks.y=ggplot2::element_blank(),
		axis.text.x=text_colour,
		axis.ticks.x=tick_colour,
		strip.background=ggplot2::element_rect(fill="white", colour="white"),
		strip.text=ggplot2::element_text(family="Courier New", face="bold", size=11),
		legend.position="none",
		legend.direction="vertical",
		panel.grid.minor.x=ggplot2::element_blank(),
		panel.grid.minor.y=ggplot2::element_blank(),
		panel.grid.major.y=ggplot2::element_blank(),
		plot.title = ggplot2::element_text(hjust = 0, size=12, colour=section_colour),
		plot.margin=ggplot2::unit(c(2,0,2,0), units="points"),
		plot.background=ggplot2::element_rect(fill="white"),
		panel.spacing=ggplot2::unit(0,"lines"),
		panel.background=ggplot2::element_rect(colour="red", fill="grey", size=1),
		strip.text.y = ggplot2::element_blank()
	) +
	ggplot2::labs(y=NULL, x=xlabname, colour="", fill=NULL, title=main_title) +
	outcome_labels
	return(p)
}




#' Forest plot for multiple exposures and multiple outcomes
#'
#' Plot the results of MR analyses of multiple exposures on multiple outcomes.
#' The output of [mr()] is first passed through [format_mr_results()], and the
#' estimates are then drawn as forest plot panels arranged in a grid.
#'
#' @param mr_res Results from [mr()].
#' @param exponentiate Convert effects to OR? Default is `FALSE`.
#' @param single_snp_method Which of the single SNP methods to use when only 1 SNP was used to estimate the causal effect? The default is `"Wald ratio"`.
#' @param multi_snp_method Which of the multi-SNP methods to use when there was more than 1 SNPs used to estimate the causal effect? The default is `"Inverse variance weighted"`.
#' @param group_single_categories If two or more categories contain only one outcome each, group them together into an "Other" category, which is placed last. The default is `TRUE`.
#' @param by_category Separate the results into sections by category? The default is `TRUE`.
#' @param in_columns Separate the exposures into different columns. The default is `FALSE`.
#' @param threshold p-value threshold to use for colouring points by significance level. If `NULL` (default) then colour layer won't be applied.
#' @param xlab x-axis label. If `in_columns=TRUE` then the exposure values are appended to the end of `xlab`. e.g. if `xlab="Effect of"` then x-labels will read `"Effect of exposure1"`, `"Effect of exposure2"` etc. Otherwise will be printed as is.
#' @param xlim limit x-axis range. Provide vector of length 2, with lower and upper bounds. The default is `NULL`.
#' @param trans Transformation to apply to x-axis. e.g. `"identity"`, `"log2"`, etc. The default is `"identity"`.
#' @param ao_slc retrieve sample size and subcategory from [available_outcomes()]. If set to `FALSE` then `mr_res` must contain the following additional columns: `sample_size` and `subcategory`. The default behaviour is to use [available_outcomes()] to retrieve sample size and subcategory.
#' @param priority Name of category to prioritise at the top of the forest plot. The default is `"Cardiometabolic"`.
#'
#' @export
#' @return grid plot object
forest_plot <- function(mr_res, exponentiate=FALSE, single_snp_method="Wald ratio", multi_snp_method="Inverse variance weighted", group_single_categories=TRUE, by_category=TRUE, in_columns=FALSE, threshold=NULL, xlab="", xlim=NULL, trans="identity",ao_slc=TRUE, priority="Cardiometabolic")
{
	dat <- format_mr_results(
		mr_res,
		exponentiate=exponentiate,
		single_snp_method=single_snp_method,
		multi_snp_method=multi_snp_method,
		ao_slc=ao_slc,
		priority=priority
	)

	if(group_single_categories)
	{
		# format_mr_results() stores the category in `category` (there is no
		# `subcategory` column). Count distinct outcomes per category and pool the
		# categories that have only one outcome.
		dat$category <- as.character(dat$category)
		first_rows <- dat[!duplicated(dat$outcome), , drop=FALSE]
		tab <- table(first_rows$category)
		othercats <- names(tab)[tab == 1]
		if(length(othercats) > 1)
		{
			dat$category[dat$category %in% othercats] <- "Other"
			# Move the pooled rows to the end; order() is stable, so the existing
			# order is otherwise kept
			dat <- dat[order(dat$category %in% "Other"), , drop=FALSE]
		}
	}

	dat$lab <- create_label(dat$sample_size, dat$outcome)

	if(!by_category)
	{
		if(!in_columns)
		{
			return(forest_plot_basic(
				dat,
				bottom = TRUE,
				xlab = xlab,
				trans = trans,
				xlim = xlim,
				threshold = threshold
			) + ggplot2::theme(legend.position="left"))
		}

		# A column of outcome names followed by one column per exposure
		columns <- unique(dat$exposure)
		l <- vector("list", length(columns) + 1)
		l[[1]] <- forest_plot_names(
			dat,
			section = NULL,
			bottom = TRUE
		)
		for(j in seq_along(columns))
		{
			l[[j + 1]] <- forest_plot_basic(
				dat,
				section=NULL,
				bottom = TRUE,
				colour_group=columns[j],
				colour_group_first = FALSE,
				xlab = paste0(xlab, " ", columns[j]),
				trans = trans,
				xlim = xlim,
				threshold = threshold
			)
		}
		return(
			cowplot::plot_grid(
				gridExtra::arrangeGrob(
					grobs=l,
					ncol=length(columns) + 1,
					nrow=1,
					widths=c(4, rep(5, length(columns)))
				)
			)
		)
	}

	sec <- unique(as.character(dat$category))
	# Relative panel heights: one unit per outcome in the section plus one (title);
	# the bottom section, which shows the x axis, gets one more.
	h <- vapply(sec, function(s) length(unique(dat$outcome[which(dat$category == s)])), integer(1), USE.NAMES=FALSE) + 1
	h[length(h)] <- h[length(h)] + 1

	if(!in_columns)
	{
		# Shared exposure legend, placed to the left of the stacked section panels
		legend <- cowplot::get_legend(
			ggplot2::ggplot(dat, ggplot2::aes(x=effect, y=outcome)) +
			ggplot2::geom_point(ggplot2::aes(colour=exposure)) +
			ggplot2::scale_colour_brewer(type="qual") +
			ggplot2::labs(colour="Exposure") +
			ggplot2::theme(text=ggplot2::element_text(size=10))
		)

		l <- vector("list", length(sec))
		for(i in seq_along(sec))
		{
			l[[i]] <- forest_plot_basic(
				dat,
				sec[i],
				bottom = i==length(sec),
				xlab = xlab,
				trans = trans,
				xlim = xlim,
				threshold = threshold
			)
		}

		return(
			cowplot::plot_grid(
				gridExtra::arrangeGrob(
					legend,
					gridExtra::arrangeGrob(grobs=l, ncol=1, nrow=length(h), heights=h),
					ncol=2, nrow=1, widths=c(1, 5)
				)
			)
		)
	}

	# One grid row per section: outcome names first, then one panel per exposure
	columns <- unique(dat$exposure)
	l <- list()
	for(i in seq_along(sec))
	{
		l[[length(l) + 1]] <- forest_plot_names(
			dat,
			sec[i],
			bottom = i==length(sec)
		)
		for(j in seq_along(columns))
		{
			l[[length(l) + 1]] <- forest_plot_basic(
				dat,
				sec[i],
				bottom = i==length(sec),
				colour_group=columns[j],
				colour_group_first = FALSE,
				xlab = paste0(xlab, " ", columns[j]),
				trans = trans,
				xlim = xlim,
				threshold = threshold
			)
		}
	}

	return(
		cowplot::plot_grid(
			gridExtra::arrangeGrob(
				grobs=l,
				ncol=length(columns) + 1,
				nrow=length(h),
				heights=h,
				widths=c(4, rep(5, length(columns)))
			)
		)
	)
}

